{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw global dependencies between input and output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import copy\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        pe = pe.unsqueeze(0).transpose(0, 1)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pe[:x.size(0), :]\n",
    "        return self.dropout(x)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransformerModel(nn.Module):\n",
    "\n",
    "    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):\n",
    "        super(TransformerModel, self).__init__()\n",
    "        from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
    "        self.model_type = 'Transformer'\n",
    "        self.src_mask = None\n",
    "        self.pos_encoder = PositionalEncoding(ninp, dropout)\n",
    "        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)\n",
    "        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        self.ninp = ninp\n",
    "        self.decoder = nn.Linear(ninp, ntoken)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def _generate_square_subsequent_mask(self, sz):\n",
    "        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
    "        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
    "        return mask\n",
    "\n",
    "    def init_weights(self):\n",
    "        initrange = 0.1\n",
    "        self.encoder.weight.data.uniform_(-initrange, initrange)\n",
    "        self.decoder.bias.data.zero_()\n",
    "        self.decoder.weight.data.uniform_(-initrange, initrange)\n",
    "\n",
    "    def forward(self, src):\n",
    "        if self.src_mask is None or self.src_mask.size(0) != len(src):\n",
    "            device = src.device\n",
    "            mask = self._generate_square_subsequent_mask(len(src)).to(device)\n",
    "            self.src_mask = mask\n",
    "\n",
    "        src = self.encoder(src) * math.sqrt(self.ninp)\n",
    "        src = self.pos_encoder(src)\n",
    "        output = self.transformer_encoder(src, self.src_mask)\n",
    "        output = self.decoder(output)\n",
    "        return output\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchtext\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "# Text pipeline: lower-case, tokenize with the \"basic_english\" tokenizer and\n",
    "# wrap each example in <sos> / <eos> markers.\n",
    "# NOTE(review): torchtext.data.Field looks like the legacy torchtext API - confirm\n",
    "# the installed torchtext version still provides it.\n",
    "TEXT = torchtext.data.Field(tokenize=get_tokenizer(\"basic_english\"),\n",
    "                            init_token='<sos>',\n",
    "                            eos_token='<eos>',\n",
    "                            lower=True)\n",
    "# WikiText2 train / validation / test splits, processed through TEXT.\n",
    "train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)\n",
    "# The vocabulary is built from the training split only.\n",
    "TEXT.build_vocab(train_txt)\n",
    "# Use CUDA when it is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def batchify(data, bsz):\n",
    "    \"\"\"Turn a dataset split into ``bsz`` parallel token streams on ``device``.\n",
    "\n",
    "    The text of the first example is numericalized with ``TEXT``, trailing\n",
    "    tokens that do not fill a complete row are dropped, and the remainder is\n",
    "    laid out so that each column is one contiguous stream.\n",
    "    \"\"\"\n",
    "    token_ids = TEXT.numericalize([data.examples[0].text])\n",
    "    # Number of complete rows available once the ids are split bsz ways.\n",
    "    rows = token_ids.size(0) // bsz\n",
    "    # Keep only the tokens that fill those rows exactly.\n",
    "    trimmed = token_ids.narrow(0, 0, rows * bsz)\n",
    "    # One stream per row, then transpose so streams become columns.\n",
    "    streams = trimmed.view(bsz, -1)\n",
    "    columns = streams.t().contiguous()\n",
    "    return columns.to(device)\n",
    "\n",
    "# Number of parallel streams: wider for training than for evaluation.\n",
    "batch_size = 20\n",
    "eval_batch_size = 10\n",
    "# Each split becomes a tensor with one column per stream, placed on `device`.\n",
    "train_data = batchify(train_txt, batch_size)\n",
    "val_data = batchify(val_txt, eval_batch_size)\n",
    "test_data = batchify(test_txt, eval_batch_size)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "bptt = 35\n",
    "def get_batch(source, i):\n",
    "    seq_len = min(bptt, len(source) - 1 - i)\n",
    "    data = source[i:i+seq_len]\n",
    "    target = source[i+1:i+1+seq_len].view(-1)\n",
    "    return data, target\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters of the language model built below.\n",
    "ntokens = len(TEXT.vocab.stoi) # the size of vocabulary\n",
    "emsize = 200 # embedding dimension\n",
    "nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder\n",
    "nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
    "nhead = 2 # the number of heads in the multiheadattention models\n",
    "dropout = 0.2 # the dropout value\n",
    "# Build the model and move its parameters and buffers to `device`.\n",
    "model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss between the predicted logits and the next-token ids.\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "lr = 5.0  # initial learning rate\n",
    "optimizer = torch.optim.SGD(params=model.parameters(), lr=lr)\n",
    "# Multiply the learning rate by 0.95 on every scheduler.step() call.\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=1.0, gamma=0.95)\n",
    "\n",
    "import time\n",
    "def train():\n",
    "    \"\"\"Run one pass over ``train_data``, updating ``model`` in place.\n",
    "\n",
    "    Uses the module-level ``model``, ``optimizer``, ``scheduler``, ``criterion``,\n",
    "    ``train_data``, ``bptt`` and ``TEXT``. The module-level ``epoch`` (set by the\n",
    "    loop that calls this function) is read for logging only. The running loss\n",
    "    is printed every ``log_interval`` batches.\n",
    "    \"\"\"\n",
    "    model.train() # Turn on the train mode\n",
    "    total_loss = 0.\n",
    "    start_time = time.time()\n",
    "    ntokens = len(TEXT.vocab.stoi)\n",
    "    log_interval = 200\n",
    "    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
    "        data, targets = get_batch(train_data, i)\n",
    "        optimizer.zero_grad()\n",
    "        output = model(data)\n",
    "        loss = criterion(output.view(-1, ntokens), targets)\n",
    "        loss.backward()\n",
    "        # Clip the global gradient norm to 0.5 before the update.\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        if batch % log_interval == 0 and batch > 0:\n",
    "            cur_loss = total_loss / log_interval\n",
    "            elapsed = time.time() - start_time\n",
    "            # get_last_lr() reports the rate actually in use; get_lr() is\n",
    "            # deprecated for this purpose and returns a misleading value\n",
    "            # when called outside scheduler.step().\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
    "                  'lr {:02.2f} | ms/batch {:5.2f} | '\n",
    "                  'loss {:5.2f} | ppl {:8.2f}'.format(\n",
    "                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],\n",
    "                    elapsed * 1000 / log_interval,\n",
    "                    cur_loss, math.exp(cur_loss)))\n",
    "            total_loss = 0\n",
    "            start_time = time.time()\n",
    "\n",
    "def evaluate(eval_model, data_source):\n",
    "    \"\"\"Return the loss of ``eval_model`` on ``data_source``, averaged per time step.\n",
    "\n",
    "    Each chunk's loss is weighted by its length and the total is divided by\n",
    "    ``len(data_source) - 1``. No gradients are tracked.\n",
    "    \"\"\"\n",
    "    eval_model.eval()  # switch to evaluation mode\n",
    "    vocab_size = len(TEXT.vocab.stoi)\n",
    "    weighted_loss = 0.\n",
    "    with torch.no_grad():\n",
    "        for start in range(0, data_source.size(0) - 1, bptt):\n",
    "            inputs, targets = get_batch(data_source, start)\n",
    "            logits = eval_model(inputs).view(-1, vocab_size)\n",
    "            chunk_loss = criterion(logits, targets).item()\n",
    "            weighted_loss += len(inputs) * chunk_loss\n",
    "    return weighted_loss / (len(data_source) - 1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/nfs_home/sohyongs/anaconda3/envs/fastai2/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:351: UserWarning: To get the last learning rate computed by the scheduler, please use `get_last_lr()`.\n",
      "  \"please use `get_last_lr()`.\", UserWarning)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| epoch   1 |   200/ 2981 batches | lr 5.00 | ms/batch 11.58 | loss  8.04 | ppl  3093.76\n",
      "| epoch   1 |   400/ 2981 batches | lr 5.00 | ms/batch  7.67 | loss  6.79 | ppl   890.87\n",
      "| epoch   1 |   600/ 2981 batches | lr 5.00 | ms/batch  7.67 | loss  6.36 | ppl   577.67\n",
      "| epoch   1 |   800/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  6.22 | ppl   501.81\n",
      "| epoch   1 |  1000/ 2981 batches | lr 5.00 | ms/batch  7.67 | loss  6.11 | ppl   452.35\n",
      "| epoch   1 |  1200/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  6.09 | ppl   440.37\n",
      "| epoch   1 |  1400/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  6.04 | ppl   418.30\n",
      "| epoch   1 |  1600/ 2981 batches | lr 5.00 | ms/batch  7.69 | loss  6.05 | ppl   423.50\n",
      "| epoch   1 |  1800/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  5.95 | ppl   384.01\n",
      "| epoch   1 |  2000/ 2981 batches | lr 5.00 | ms/batch  7.67 | loss  5.96 | ppl   388.41\n",
      "| epoch   1 |  2200/ 2981 batches | lr 5.00 | ms/batch  7.69 | loss  5.84 | ppl   345.23\n",
      "| epoch   1 |  2400/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  5.89 | ppl   361.46\n",
      "| epoch   1 |  2600/ 2981 batches | lr 5.00 | ms/batch  7.68 | loss  5.89 | ppl   362.76\n",
      "| epoch   1 |  2800/ 2981 batches | lr 5.00 | ms/batch  7.71 | loss  5.79 | ppl   328.20\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   1 | time: 25.00s | valid loss  5.78 | valid ppl   323.68\n",
      "-----------------------------------------------------------------------------------------\n",
      "| epoch   2 |   200/ 2981 batches | lr 4.51 | ms/batch  7.72 | loss  5.80 | ppl   328.69\n",
      "| epoch   2 |   400/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.77 | ppl   319.16\n",
      "| epoch   2 |   600/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.59 | ppl   269.05\n",
      "| epoch   2 |   800/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.63 | ppl   278.22\n",
      "| epoch   2 |  1000/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.58 | ppl   265.54\n",
      "| epoch   2 |  1200/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.61 | ppl   272.87\n",
      "| epoch   2 |  1400/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.62 | ppl   275.73\n",
      "| epoch   2 |  1600/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.66 | ppl   285.95\n",
      "| epoch   2 |  1800/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.58 | ppl   265.17\n",
      "| epoch   2 |  2000/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.61 | ppl   274.50\n",
      "| epoch   2 |  2200/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.50 | ppl   244.99\n",
      "| epoch   2 |  2400/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.57 | ppl   261.93\n",
      "| epoch   2 |  2600/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.58 | ppl   265.83\n",
      "| epoch   2 |  2800/ 2981 batches | lr 4.51 | ms/batch  7.68 | loss  5.50 | ppl   245.18\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   2 | time: 24.20s | valid loss  5.55 | valid ppl   256.07\n",
      "-----------------------------------------------------------------------------------------\n",
      "| epoch   3 |   200/ 2981 batches | lr 4.29 | ms/batch  7.72 | loss  5.54 | ppl   253.90\n",
      "| epoch   3 |   400/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.54 | ppl   255.43\n",
      "| epoch   3 |   600/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.35 | ppl   210.76\n",
      "| epoch   3 |   800/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.40 | ppl   222.16\n",
      "| epoch   3 |  1000/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.36 | ppl   212.88\n",
      "| epoch   3 |  1200/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.40 | ppl   221.11\n",
      "| epoch   3 |  1400/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.42 | ppl   226.10\n",
      "| epoch   3 |  1600/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.46 | ppl   234.92\n",
      "| epoch   3 |  1800/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.40 | ppl   220.61\n",
      "| epoch   3 |  2000/ 2981 batches | lr 4.29 | ms/batch  7.67 | loss  5.42 | ppl   226.96\n",
      "| epoch   3 |  2200/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.30 | ppl   201.13\n",
      "| epoch   3 |  2400/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.38 | ppl   217.03\n",
      "| epoch   3 |  2600/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.40 | ppl   221.78\n",
      "| epoch   3 |  2800/ 2981 batches | lr 4.29 | ms/batch  7.68 | loss  5.33 | ppl   206.05\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   3 | time: 24.20s | valid loss  5.52 | valid ppl   250.60\n",
      "-----------------------------------------------------------------------------------------\n"
     ]
    }
   ],
   "source": [
    "best_val_loss = float(\"inf\")\n",
    "epochs = 3 # The number of epochs\n",
    "best_model = None\n",
    "\n",
    "# Train for `epochs` passes, validating after each one and keeping a snapshot\n",
    "# of the weights with the lowest validation loss.\n",
    "for epoch in range(1, epochs + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train()\n",
    "    val_loss = evaluate(model, val_data)\n",
    "    print('-' * 89)\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '\n",
    "          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),\n",
    "                                     val_loss, math.exp(val_loss)))\n",
    "    print('-' * 89)\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        # Deep-copy the model: a plain reference keeps changing as training\n",
    "        # continues, so best_model would always end up as the final weights.\n",
    "        best_model = copy.deepcopy(model)\n",
    "\n",
    "    # Decay the learning rate once per epoch.\n",
    "    scheduler.step()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=========================================================================================\n",
      "| End of training | test loss  5.43 | test ppl   228.08\n",
      "=========================================================================================\n"
     ]
    }
   ],
   "source": [
    "# Score the model kept in `best_model` on the held-out test split.\n",
    "test_loss = evaluate(best_model, test_data)\n",
    "test_ppl = math.exp(test_loss)\n",
    "rule = '=' * 89\n",
    "print(rule)\n",
    "print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(test_loss, test_ppl))\n",
    "print(rule)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([28785, 200])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Embedding matrix size: one row per vocabulary entry, one column per embedding feature.\n",
    "model.encoder.weight.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TransformerModel(\n",
       "  (pos_encoder): PositionalEncoding(\n",
       "    (dropout): Dropout(p=0.2, inplace=False)\n",
       "  )\n",
       "  (transformer_encoder): TransformerEncoder(\n",
       "    (layers): ModuleList(\n",
       "      (0): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): Linear(in_features=200, out_features=200, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.2, inplace=False)\n",
       "        (dropout2): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "      (1): TransformerEncoderLayer(\n",
       "        (self_attn): MultiheadAttention(\n",
       "          (out_proj): Linear(in_features=200, out_features=200, bias=True)\n",
       "        )\n",
       "        (linear1): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (dropout): Dropout(p=0.2, inplace=False)\n",
       "        (linear2): Linear(in_features=200, out_features=200, bias=True)\n",
       "        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)\n",
       "        (dropout1): Dropout(p=0.2, inplace=False)\n",
       "        (dropout2): Dropout(p=0.2, inplace=False)\n",
       "      )\n",
       "    )\n",
       "  )\n",
       "  (encoder): Embedding(28785, 200)\n",
       "  (decoder): Linear(in_features=200, out_features=28785, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the module tree of the model.\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
